# -*- coding:utf-8 -*- 
# Utils module: useful functions to build exploits
from ropgenerator.semantic.Engine import search, LMAX
from ropgenerator.Constraints import Constraint, RegsNotModified, Assertion, Chainable, StackPointerIncrement
from ropgenerator.semantic.ROPChains import ROPChain
from ropgenerator.Database import QueryType, DBAllPossibleWrites
from ropgenerator.exploit.Scanner import getFunctionAddress, findBytes
from ropgenerator.IO import verbose, string_bold, string_special, string_ropg, error
import itertools
import ropgenerator.Architecture as Arch 

########################
# ROP-Chains functions #
########################

# Default value of the clmax argument of store_constant_address()
STORE_CONSTANT_ADDRESS_LMAX = 80
def store_constant_address(qtype, cst_addr, value, constraint=None, assertion=None, clmax=None, optimizeLen=False):
    """
    Does a XXXtoMEM kind of query BUT the memory address is
    a simple constant !

    For each possible write (addr_reg, addr_cst), (reg, cst), gadget returned
    by DBAllPossibleWrites(), the code treats the write as storing reg+cst at
    address addr_reg+addr_cst (hence the adjustments of value and cst_addr
    below). It searches a chain that sets `reg` and `addr_reg`, then appends
    the write gadget and its stack padding. Both orders are tried:
    value first (the address search must not modify `reg`), then address
    first (the value search must not modify `addr_reg`).

    Parameters:
        qtype       -- QueryType.CSTtoMEM or QueryType.REGtoMEM
                       (QueryType.MEMtoREG is also accepted); any other
                       query type raises Exception
        cst_addr    -- the constant store address
        value       -- the value to store, a single cst or a couple (reg,cst)
        constraint  -- Constraint to respect (default: Constraint())
        assertion   -- Assertion to use (default: Assertion())
        clmax       -- maximum chain length, as measured by len() of a
                       ROPChain (default: STORE_CONSTANT_ADDRESS_LMAX);
                       None is returned right away if clmax <= 0
        optimizeLen -- if False, return the first chain found; if True,
                       keep looking for a shorter one (with early exits
                       on short enough chains)
    Returns:
        the ROPChain found, or None if no chain was found
    """
    if clmax is None:
        clmax = STORE_CONSTANT_ADDRESS_LMAX
    elif clmax <= 0:
        return None

    constr = Constraint() if constraint is None else constraint
    a = Assertion() if assertion is None else assertion

    # Transform the query type: the value is first put into a register
    if qtype == QueryType.CSTtoMEM:
        qtype2 = QueryType.CSTtoREG
    elif qtype == QueryType.REGtoMEM:
        qtype2 = QueryType.REGtoREG
    elif qtype == QueryType.MEMtoREG:
        # NOTE(review): the docstring says only XXXtoMEM types are expected;
        # this was possibly meant to be MEMtoMEM -> MEMtoREG. Kept as-is
        # because callers may rely on it - confirm.
        qtype2 = QueryType.MEMtoREG
    else:
        raise Exception("Query type {} should not appear in this function!".format(qtype))

    # Registers that must never be used as source or address register
    forbidden_regs = (Arch.ipNum(), Arch.spNum())
    tried_values = []    # (reg, cst) that could not be loaded with the value
    tried_cst_addr = []  # (addr_reg, addr_cst) that could not be loaded with the address
    best = None  # Best chain found so far (only used if optimizeLen)
    shortest = clmax  # Length of the best chain found if optimizeLen

    def add_write_gadget(chain, gadget):
        # Append the write gadget then pad the extra stack words it pops.
        # Floor division keeps the padding count an integer on Python 2 and 3.
        res = chain.addGadget(gadget)
        if gadget.spInc > 0:
            padding_value = constr.getValidPadding(Arch.octets())
            res = res.addPadding(padding_value, n=(gadget.spInc // Arch.octets()) - 1)
        return res

    # Writes whose source already is `value` are tried first
    possible_writes = sorted(DBAllPossibleWrites(constr.add(Chainable(ret=True)), a),
                             key=lambda x: 0 if (x[1] == value) else 1)
    for ((addr_reg, addr_cst), (reg, cst), gadget) in possible_writes:
        # Don't use the instruction pointer or the stack pointer
        if reg in forbidden_regs or addr_reg in forbidden_regs:
            continue
        # Check if the source of the write directly is the wanted value
        value_is_reg = False
        value_to_reg = []
        if (reg, cst) == value:
            value_to_reg = [ROPChain()]
            value_is_reg = True

        # Compensate the constant offsets of the write gadget
        if not isinstance(value, tuple):
            adjusted_value = value - cst
        else:
            adjusted_value = (value[0], value[1] - cst)
        adjusted_cst_addr = cst_addr - addr_cst
        # Number of padding words after the write gadget
        gadget_paddingLen = (gadget.spInc // Arch.octets()) - 1
        # Skip sources/addresses that already failed
        if (reg, cst) in tried_values:
            continue
        elif (addr_reg, addr_cst) in tried_cst_addr:
            continue

        ### Try to do reg first then addr_reg
        clmax2 = shortest - gadget_paddingLen - 1
        if not value_is_reg:
            value_to_reg = search(qtype2, reg, adjusted_value, constr, a, clmax=clmax2, n=1, optimizeLen=optimizeLen)
            if not value_to_reg:
                tried_values.append((reg, cst))
                continue
            clmax2 = clmax2 - len(value_to_reg[0])
        # Put the address in addr_reg without modifying reg
        addr_to_reg = search(QueryType.CSTtoREG, addr_reg, adjusted_cst_addr, constr.add(RegsNotModified([reg])), a, clmax=clmax2, n=1, optimizeLen=optimizeLen)
        if addr_to_reg:
            res = add_write_gadget(value_to_reg[0].addChain(addr_to_reg[0]), gadget)
            if not optimizeLen:
                return res
            best = min(best, res) if best else res
            shortest = len(best)

        ### Try to do addr_reg first and then reg
        clmax2 = shortest - gadget_paddingLen - 1
        addr_to_reg = search(QueryType.CSTtoREG, addr_reg, adjusted_cst_addr, constr, a, clmax=clmax2, n=1, optimizeLen=optimizeLen)
        if not addr_to_reg:
            tried_cst_addr.append((addr_reg, addr_cst))
            continue
        clmax2 = clmax2 - len(addr_to_reg[0])
        # Put the value in reg without modifying addr_reg
        if not value_is_reg:
            value_to_reg = search(qtype2, reg, adjusted_value, constr.add(RegsNotModified([addr_reg])), a, clmax=clmax2, n=1, optimizeLen=optimizeLen)
        if value_to_reg:
            res = add_write_gadget(addr_to_reg[0].addChain(value_to_reg[0]), gadget)
            if not optimizeLen:
                return res
            best = min(best, res) if best else res
            shortest = len(best)

        # Early exit: per the original author, 5 = two pops (for addr_reg and
        # reg) + the write gadget - NOTE(review): assumes each pop counts as
        # 2 chain elements (gadget + popped value). Fewer than 5 is possible
        # only if reg already holds 'value', which is tried first (sorted()).
        if ((not optimizeLen) or (not value_is_reg)) and (best is not None) and len(best) <= 5:
            return best
        elif optimizeLen and (best is not None) and len(best) <= 3:
            return best
    return best


########################
#   Parsing functions  #
########################

def parseFunction(string):
    r"""
    Parse a function call string such as 'f(1, 0x10, "abc")'.

    Whitespace outside of quoted strings is ignored; whitespace inside a
    '...' or "..." string argument is kept. Arguments can be:
      - strings between '...' or "..." (no escaping of quotes); the
        sequence \xHH is replaced by the character with hex code HH,
      - integers, tried in order as decimal, binary (e.g. 0b101) and
        hexadecimal (with or without the 0x prefix).

    Returns (funcName, args) where args holds str (strings) and int
    (constants) values, or (None, None) after reporting the problem with
    error() when the string is invalid.
    """
    def strip_unquoted_whitespace(text):
        # Remove whitespace except inside '...' / "..." strings so that
        # string arguments keep their spaces (e.g. "/bin/sh -c ls")
        kept = []
        quote = None
        for ch in text:
            if quote is not None:
                kept.append(ch)
                if ch == quote:
                    quote = None
            elif ch == '"' or ch == "'":
                quote = ch
                kept.append(ch)
            elif not ch.isspace():
                kept.append(ch)
        return "".join(kept)

    hex_digits = "0123456789abcdefABCDEF"

    if not string:
        error("Missing fuction to call")
        return (None, None)

    # Compress the string (outside of quoted arguments)
    string = strip_unquoted_whitespace(string)

    # Get the function name (must be non-empty and followed by '(')
    index = string.find("(")
    if index <= 0:
        error("Invalid function call")
        return (None, None)
    funcName = string[:index]
    rest = string[index+1:]
    args = []
    i = 0
    end = False
    while i < len(rest):
        c = rest[i]
        # No args
        if c == ")" and not args:
            end = True
            i += 1
        # String: runs up to the next identical quote
        elif c == '"' or c == "'":
            close = rest.find(c, i + 1)
            if close == -1:
                error("Missing closing {} for string".format(c))
                return (None, None)
            s = str(rest[i+1:close])
            if not s:
                error("Error. Empty string argument ?")
                return (None, None)
            # Replace \xHH escapes by the corresponding character
            parsed_string = []
            j = 0
            while j < len(s):
                if s[j:j+2] == "\\x":
                    byte = s[j+2:j+4]
                    # Require exactly two hex digits (int() alone would
                    # accept e.g. '-1' and chr() would then crash)
                    if len(byte) != 2 or not all(h in hex_digits for h in byte):
                        error("Invalid byte: '{}'".format(s[j:j+4]))
                        return (None, None)
                    parsed_string.append(chr(int(byte, 16)))
                    j += 4
                else:
                    parsed_string.append(s[j])
                    j += 1
            args.append(str("".join(parsed_string)))

            # A string must be followed by ',' or ')'
            i = close + 1
            if i >= len(rest):
                error("Error. Missing ')'")
                return (None, None)
            elif rest[i] == ')':
                end = True
                i += 1
            elif rest[i] == ",":
                i += 1
            else:
                error("Error. Missing ',' or ')' after string")
                return (None, None)
        # Constant: runs up to the next ',' or ')'
        else:
            arg = ''
            ok = False
            for j in range(i, len(rest)):
                if rest[j] == ")":
                    end = True
                    ok = True
                    break
                elif rest[j] == ',':
                    ok = True
                    break
                else:
                    arg += rest[j]
            if not ok:
                error("Missing ')' after argument")
                return (None, None)
            if (not arg) and args:
                error("Missing argument")
                return (None, None)
            # Convert to int. Binary is tried before hexadecimal because a
            # string like '0b101' is also valid hexadecimal (0xb101).
            for base in (10, 2, 16):
                try:
                    value = int(arg, base)
                except ValueError:
                    continue
                break
            else:
                error("Invalid operand: " + arg)
                return (None, None)
            args.append(value)
            i = j + 1
        if end:
            break

    if not end:
        error("Error. Missing ')'")
        return (None, None)
    if i < len(rest):
        error("Error. Extra argument: {}".format(rest[i:]))
        return (None, None)

    # str() to set its type to str ;)
    return (str(funcName), args)
    
    
def parse_bad_bytes(string):
    """
    Parse a comma-separated bad-bytes specification.

    Input: a string such as "00,0A,FF,32,C7"
    Output:
        (True, ['00', '0a', 'ff', '32', 'c7']) when every item is exactly
        two hexadecimal digits (items are lower-cased, order is kept);
        otherwise (False, error_message) describing the first bad item.
    """
    valid_digits = '0123456789abcdefABCDEF'
    parsed = []
    for candidate in (item.lower() for item in string.split(',')):
        if not candidate:
            return (False, "Error. Missing bad byte after ','")
        if len(candidate) != 2 or not all(digit in valid_digits for digit in candidate):
            return (False, "Error. '{}' is not a valid byte".format(candidate))
        parsed.append(candidate)
    return (True, parsed)
    
def parse_keep_regs(string):
    """
    Parse a comma-separated list of register names to keep unmodified.

    Input: a string such as "rax,rcx,rdi"
    Output:
        (True, list) holding the register numbers returned by Arch.n2r()
        for the given names (collected in a set, so duplicates are merged
        and the order is not guaranteed); or (False, error_message) for the
        first name that is not a key of Arch.regNameToNum.
    """
    numbers = set()
    for name in string.split(','):
        if name not in Arch.regNameToNum:
            return (False, "Error. '{}' is not a valid register".format(name))
        numbers.add(Arch.n2r(name))
    return (True, list(numbers))
